namespace HZY.Framework.Repository.EntityFramework.BulkCopys;

/// <summary>
/// SQL Server bulk-copy helpers for EF Core, built on <see cref="SqlBulkCopy"/>.
/// </summary>
public static class SqlServerBulkCopyExtensions
{
    /// <summary>
    /// Bulk copies every row of <paramref name="dataTable"/> into <paramref name="tableName"/> on SQL Server.
    /// </summary>
    /// <param name="database">EF Core database facade; must be using the SQL Server provider.</param>
    /// <param name="dataTable">Source rows. Each column is mapped by name to a destination column of the same name.</param>
    /// <param name="tableName">Destination table name.</param>
    /// <param name="dbTransaction">
    /// Optional transaction to run the copy in. When not <c>null</c> it must be a <see cref="SqlTransaction"/>.
    /// </param>
    /// <exception cref="InvalidOperationException">The database is not SQL Server.</exception>
    /// <remarks>
    /// The connection is opened if needed. It is closed again afterwards only if this method opened it.
    /// A connection that was already open (e.g. one carrying <paramref name="dbTransaction"/>) is left open.
    /// </remarks>
    public static void SqlServerBulkCopy(this DatabaseFacade database, DataTable dataTable, string tableName,
        IDbTransaction? dbTransaction)
    {
        var dbConnection = GetSqlServerConnection(database);

        // `using` guarantees the bulk copy is released even if opening the connection throws.
        using var sqlBulkCopy = CreateSqlBulkCopy(dbConnection, dataTable, tableName, dbTransaction);

        // Track ownership: closing a connection we did not open would abort the caller's transaction
        // and pull the connection out from under EF Core.
        var openedHere = false;
        if (dbConnection.State != ConnectionState.Open)
        {
            dbConnection.Open();
            openedHere = true;
        }

        try
        {
            sqlBulkCopy.WriteToServer(dataTable);
        }
        finally
        {
            if (openedHere && dbConnection.State == ConnectionState.Open)
            {
                dbConnection.Close();
            }
        }
    }

    /// <summary>
    /// Asynchronously bulk copies every row of <paramref name="dataTable"/> into <paramref name="tableName"/> on SQL Server.
    /// </summary>
    /// <param name="database">EF Core database facade; must be using the SQL Server provider.</param>
    /// <param name="dataTable">Source rows. Each column is mapped by name to a destination column of the same name.</param>
    /// <param name="tableName">Destination table name.</param>
    /// <param name="dbTransaction">
    /// Optional transaction to run the copy in. When not <c>null</c> it must be a <see cref="SqlTransaction"/>.
    /// </param>
    /// <param name="cancellationToken">Token used to cancel opening the connection and the copy itself.</param>
    /// <returns>A task that completes when the copy has finished.</returns>
    /// <exception cref="InvalidOperationException">The database is not SQL Server.</exception>
    /// <remarks>
    /// The connection is opened if needed. It is closed again afterwards only if this method opened it.
    /// A connection that was already open (e.g. one carrying <paramref name="dbTransaction"/>) is left open.
    /// </remarks>
    public static async Task SqlServerBulkCopyAsync(this DatabaseFacade database, DataTable dataTable, string tableName,
        IDbTransaction? dbTransaction, System.Threading.CancellationToken cancellationToken = default)
    {
        var dbConnection = GetSqlServerConnection(database);

        // `using` guarantees the bulk copy is released even if opening the connection throws.
        using var sqlBulkCopy = CreateSqlBulkCopy(dbConnection, dataTable, tableName, dbTransaction);

        // Track ownership: closing a connection we did not open would abort the caller's transaction
        // and pull the connection out from under EF Core.
        var openedHere = false;
        if (dbConnection.State != ConnectionState.Open)
        {
            await dbConnection.OpenAsync(cancellationToken).ConfigureAwait(false);
            openedHere = true;
        }

        try
        {
            await sqlBulkCopy.WriteToServerAsync(dataTable, cancellationToken).ConfigureAwait(false);
        }
        finally
        {
            if (openedHere && dbConnection.State == ConnectionState.Open)
            {
                await dbConnection.CloseAsync().ConfigureAwait(false);
            }
        }
    }

    /// <summary>
    /// Bulk copies <paramref name="items"/> into SQL Server, converting them with <c>ToDataTable()</c> first.
    /// </summary>
    /// <param name="database">EF Core database facade; must be using the SQL Server provider.</param>
    /// <param name="items">Entities to copy.</param>
    /// <param name="tableName">
    /// Destination table name. When null, empty or whitespace, the name from the type's table attribute
    /// is used, falling back to the CLR type name of <typeparamref name="T"/>.
    /// </param>
    /// <param name="dbTransaction">
    /// Optional transaction to run the copy in. When not <c>null</c> it must be a <see cref="SqlTransaction"/>.
    /// </param>
    /// <typeparam name="T">Entity type being copied.</typeparam>
    /// <exception cref="InvalidOperationException">The database is not SQL Server.</exception>
    public static void SqlServerBulkCopy<T>(this DatabaseFacade database, List<T> items, string? tableName,
        IDbTransaction? dbTransaction)
        where T : class, new()
    {
        var dataTable = items.ToDataTable();

        database.SqlServerBulkCopy(dataTable, ResolveTableName<T>(tableName), dbTransaction);
    }

    /// <summary>
    /// Asynchronously bulk copies <paramref name="items"/> into SQL Server, converting them with <c>ToDataTable()</c> first.
    /// </summary>
    /// <param name="database">EF Core database facade; must be using the SQL Server provider.</param>
    /// <param name="items">Entities to copy.</param>
    /// <param name="tableName">
    /// Destination table name. When null, empty or whitespace, the name from the type's table attribute
    /// is used, falling back to the CLR type name of <typeparamref name="T"/>.
    /// </param>
    /// <param name="dbTransaction">
    /// Optional transaction to run the copy in. When not <c>null</c> it must be a <see cref="SqlTransaction"/>.
    /// </param>
    /// <param name="cancellationToken">Token used to cancel the copy.</param>
    /// <typeparam name="T">Entity type being copied.</typeparam>
    /// <returns>A task that completes when the copy has finished.</returns>
    /// <exception cref="InvalidOperationException">The database is not SQL Server.</exception>
    public static Task SqlServerBulkCopyAsync<T>(this DatabaseFacade database, List<T> items, string? tableName,
        IDbTransaction? dbTransaction, System.Threading.CancellationToken cancellationToken = default)
        where T : class, new()
    {
        // Conversion runs synchronously before the task is returned, as before.
        var dataTable = items.ToDataTable();

        return database.SqlServerBulkCopyAsync(dataTable, ResolveTableName<T>(tableName), dbTransaction,
            cancellationToken);
    }

    /// <summary>
    /// Verifies the facade uses the SQL Server provider and returns its underlying <see cref="SqlConnection"/>.
    /// </summary>
    private static SqlConnection GetSqlServerConnection(DatabaseFacade database)
    {
        if (!database.IsSqlServer())
        {
            throw new InvalidOperationException("当前不是 SqlServer 数据库无法调用此函数!");
        }

        return (SqlConnection)database.GetDbConnection();
    }

    /// <summary>
    /// Creates a <see cref="SqlBulkCopy"/> targeting <paramref name="tableName"/>.
    /// It maps every column of <paramref name="dataTable"/> by name and sends all rows in a single batch.
    /// </summary>
    private static SqlBulkCopy CreateSqlBulkCopy(SqlConnection dbConnection, DataTable dataTable, string tableName,
        IDbTransaction? dbTransaction)
    {
        var sqlBulkCopy = dbTransaction == null
            ? new SqlBulkCopy(dbConnection)
            : new SqlBulkCopy(dbConnection, SqlBulkCopyOptions.Default, (SqlTransaction)dbTransaction);

        sqlBulkCopy.DestinationTableName = tableName;
        sqlBulkCopy.BatchSize = dataTable.Rows.Count;
        foreach (DataColumn column in dataTable.Columns)
        {
            sqlBulkCopy.ColumnMappings.Add(column.ColumnName, column.ColumnName);
        }

        return sqlBulkCopy;
    }

    /// <summary>
    /// Picks the destination table name for <typeparamref name="T"/>.
    /// An explicit non-blank <paramref name="tableName"/> wins; otherwise the type's table attribute name is used.
    /// If there is no attribute, the CLR type name is used.
    /// </summary>
    private static string ResolveTableName<T>(string? tableName)
    {
        if (!string.IsNullOrWhiteSpace(tableName))
        {
            return tableName;
        }

        // Previously `tableName ?? type.Name` kept "" / whitespace here; always fall back to a real name.
        var type = typeof(T);
        return type.GetTableAttribute()?.Name ?? type.Name;
    }
}